# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import argparse
from pathlib import Path

import numpy as np
import torch
from .models import build_ACT_model, build_CNNMLP_model, build_HATACT_model
import IPython
# Debugging shortcut: calling e() drops into an interactive IPython shell.
e = IPython.embed

class AttrDict(dict):
    """Dictionary whose items can also be read and written as attributes.

    Reading a key that is not present yields ``False`` instead of raising,
    so optional configuration flags can be tested directly
    (``if args.masks: ...``). Attribute assignment stores the value as a
    dictionary item.
    """

    def __getattr__(self, key):
        """Return ``self[key]``, or ``False`` when the key is absent.

        Raises:
            AttributeError: for special (dunder) names that normal lookup
                did not find.
        """
        # __getattr__ only runs after normal attribute lookup has failed.
        # Protocol probes such as getattr(obj, '__deepcopy__', None) rely on
        # AttributeError to detect "not implemented"; answering False there
        # makes copy.deepcopy (and similar machinery) try to call a bool.
        if key.startswith('__') and key.endswith('__'):
            raise AttributeError(key)
        return self.get(key, False)

    def __setattr__(self, key, value):
        """Store ``value`` under ``key`` as a dictionary item."""
        self[key] = value

def get_args_parser():
    """Build the argument parser holding the model/training options.

    The parser is created with ``add_help=False`` so it can be used as a
    parent of another parser. Several options duplicate those of the
    training script only so that parsing a shared command line does not
    fail; the inline comments mark values that are overridden or unused.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
    parser.add_argument('--lr', default=1e-4, type=float) # will be overridden
    parser.add_argument('--loss', action='store', type=str, help='loss', required=False)
    parser.add_argument('--lr_backbone', default=1e-5, type=float) # will be overridden
    parser.add_argument('--batch_size', default=2, type=int) # not used
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--epochs', default=300, type=int) # not used
    parser.add_argument('--lr_drop', default=200, type=int) # not used
    parser.add_argument('--clip_max_norm', default=0.1, type=float, # not used
                        help='gradient clipping max norm')

    # Model parameters
    # * Backbone
    parser.add_argument('--backbone', default='resnet18', type=str, # will be overridden
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--dilation', action='store_true',
                        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features")
    # nargs='*' collects whole names; type=list would split one string into
    # its individual characters ('top' -> ['t', 'o', 'p']).
    parser.add_argument('--camera_names', default=[], type=str, nargs='*', # will be overridden
                        help="A list of camera names")

    # * Transformer
    parser.add_argument('--enc_layers', default=4, type=int, # will be overridden
                        help="Number of encoding layers in the transformer")
    parser.add_argument('--dec_layers', default=6, type=int, # will be overridden
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=2048, type=int, # will be overridden
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int, # will be overridden
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.1, type=float,
                        help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=8, type=int, # will be overridden
                        help="Number of attention heads inside the transformer's attentions")
    parser.add_argument('--num_queries', default=400, type=int, # will be overridden
                        help="Number of query slots")
    parser.add_argument('--pre_norm', action='store_true')

    # * Segmentation
    parser.add_argument('--masks', action='store_true',
                        help="Train segmentation head if the flag is provided")

    # repeat args in imitate_episodes just to avoid error. Will not be used
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--resume_ckpt_path', action='store', type=str, help='resume_ckpt_path', required=False)
    parser.add_argument('--project', action='store', type=str, help='project', required=False)
    parser.add_argument('--name', action='store', type=str, help='name', required=False)
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--train_eval', action='store_true')
    parser.add_argument('--wandb', action='store_true')
    parser.add_argument('--onscreen_render', action='store_true')
    parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=False)
    parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=False)
    parser.add_argument('--task_name', action='store', type=str, help='task_name', required=False)
    parser.add_argument('--seed', action='store', type=int, help='seed', required=False)
    parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=False)
    # --dec_layers is not repeated here: it is already registered in the
    # transformer section above, and argparse rejects duplicate options.
    parser.add_argument('--num_blocks', action='store', type=int, help='num_blocks', required=False)
    parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
    parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
    parser.add_argument('--temporal_agg', action='store_true')
    parser.add_argument('--dual_latent', action='store_true', help='use latent z', required=False)
    parser.add_argument('--single_latent', action='store_true', help='use latent z', required=False)

    return parser


def build_ACT_model_and_optimizer(args_override):
    """Create the ACT model and its AdamW optimizer.

    A fixed table of default settings is overlaid with ``args_override``
    (override entries win; unknown keys are added) and handed to
    ``build_ACT_model`` as an ``AttrDict``. The model is moved to CUDA.
    Trainable parameters whose name contains "backbone" form a separate
    optimizer group using ``lr_backbone``; every other trainable parameter
    uses ``lr``.

    Args:
        args_override: mapping of setting name -> value; must provide
            ``.items()``.

    Returns:
        tuple: ``(model, optimizer)``.
    """
    # Defaults are hard-coded here rather than parsed from the command line.
    settings = {
        'project': False,
        'name': '',
        'eval': False,
        'resume': False,
        'resume_ckpt_path': None,
        'train_eval': False,
        'wandb_': False,
        'onscreen_render': False,
        'save_video': False,
        'ckpt_dir': 'runs/test',
        'policy_class': 'HATACT',
        'task_name': 'test',
        'batch_size': 8,
        'seed': 0,
        'num_epochs': 5000,
        'loss': 'l1',
        'lr': 1e-4,
        'dual_latent': False,
        'single_latent': False,
        'dec_layers': 7,
        'num_blocks': 1,
        'kl_weight': 1,
        'chunk_size': 50,
        'hidden_dim': 512,
        'dim_feedforward': 2048,
        'temporal_agg': False,
        'notes': '',
        'position_embedding': 'sine',
        'camera_names': [],
        'enc_layers': 4,
        'nheads': 8,
        'num_queries': 100,
        'dropout': 0.1,
        'pre_norm': False,
        'masks': False,
        'backbone': 'resnet18',
        'dilation': False,
        'lr_backbone': 1e-5,
        'weight_decay': 1e-4,
        'epochs': 300,
        'lr_drop': 200,
        'clip_max_norm': 0.1,
    }
    # Caller-supplied values take precedence over the defaults.
    settings.update(args_override.items())
    args = AttrDict(settings)

    model = build_ACT_model(args)
    model.cuda()

    # Single pass over the parameters, split by whether they belong to the
    # backbone; frozen parameters are left out of the optimizer entirely.
    main_params = []
    backbone_params = []
    for param_name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "backbone" in param_name:
            backbone_params.append(param)
        else:
            main_params.append(param)

    param_groups = [
        {"params": main_params},
        {"params": backbone_params, "lr": args.lr_backbone},
    ]
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=args.lr,
        weight_decay=args.weight_decay,
    )

    return model, optimizer

def build_HATACT_model_and_optimizer(args_override):
    """Create the HATACT model and its AdamW optimizer.

    A fixed table of default settings is overlaid with ``args_override``
    (override entries win; unknown keys are added) and handed to
    ``build_HATACT_model`` as an ``AttrDict``. The model is moved to CUDA.
    Trainable parameters whose name contains "backbone" form a separate
    optimizer group using ``lr_backbone``; every other trainable parameter
    uses ``lr``.

    Args:
        args_override: mapping of setting name -> value; must provide
            ``.items()``.

    Returns:
        tuple: ``(model, optimizer)``.
    """
    # Defaults are hard-coded here rather than parsed from the command line.
    settings = {
        'project': False,
        'name': '',
        'eval': False,
        'resume': False,
        'resume_ckpt_path': None,
        'train_eval': False,
        'wandb_': False,
        'onscreen_render': False,
        'save_video': False,
        'ckpt_dir': 'runs/test',
        'policy_class': 'HATACT',
        'task_name': 'test',
        'batch_size': 8,
        'seed': 0,
        'num_epochs': 5000,
        'loss': 'l1',
        'lr': 1e-4,
        'dual_latent': False,
        'single_latent': False,
        'dec_layers': 7,
        'num_blocks': 1,
        'kl_weight': 1,
        'chunk_size': 50,
        'hidden_dim': 512,
        'dim_feedforward': 2048,
        'temporal_agg': False,
        'notes': '',
        'position_embedding': 'sine',
        'camera_names': [],
        'enc_layers': 4,
        'nheads': 8,
        'num_queries': 100,
        'dropout': 0.1,
        'pre_norm': False,
        'masks': False,
        'backbone': 'resnet18',
        'dilation': False,
        'lr_backbone': 1e-5,
        'weight_decay': 1e-4,
        'epochs': 300,
        'lr_drop': 200,
        'clip_max_norm': 0.1,
    }
    # Caller-supplied values take precedence over the defaults.
    settings.update(args_override.items())
    args = AttrDict(settings)

    model = build_HATACT_model(args)
    model.cuda()

    # Single pass over the parameters, split by whether they belong to the
    # backbone; frozen parameters are left out of the optimizer entirely.
    main_params = []
    backbone_params = []
    for param_name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "backbone" in param_name:
            backbone_params.append(param)
        else:
            main_params.append(param)

    param_groups = [
        {"params": main_params},
        {"params": backbone_params, "lr": args.lr_backbone},
    ]
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=args.lr,
        weight_decay=args.weight_decay,
    )

    return model, optimizer


def build_CNNMLP_model_and_optimizer(args_override):
    """Create the CNN-MLP model and its AdamW optimizer.

    Settings start from the defaults declared in ``get_args_parser`` and are
    then overlaid with ``args_override``. The result is passed to
    ``build_CNNMLP_model`` as an ``AttrDict``. The model is moved to CUDA.
    Trainable parameters whose name contains "backbone" form a separate
    optimizer group using ``lr_backbone``; every other trainable parameter
    uses ``lr``.

    Args:
        args_override: mapping of setting name -> value; must provide
            ``.items()``.

    Returns:
        tuple: ``(model, optimizer)``.
    """
    # Parse an empty argument list so only the parser defaults are used and
    # the process's real command line is never consumed here.
    defaults = vars(get_args_parser().parse_args([]))

    # A plain dict cannot take setattr() or attribute reads (the previous
    # code raised AttributeError on the first override); AttrDict supports
    # both and matches what the other builders pass to their models.
    args = AttrDict(defaults)
    for k, v in args_override.items():
        args[k] = v

    model = build_CNNMLP_model(args)
    model.cuda()

    param_dicts = [
        {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
        {
            "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
            "lr": args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                  weight_decay=args.weight_decay)

    return model, optimizer

